
package com.jstarcraft.ai.jsat.regression;

import java.util.ArrayList;
import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.DataStore;
import com.jstarcraft.ai.jsat.RowMajorStore;
import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.DataPointPair;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.IndexValue;
import com.jstarcraft.ai.jsat.linear.SparseVector;
import com.jstarcraft.ai.jsat.linear.Vec;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;

/**
 * A RegressionDataSet is a data set specifically for the task of performing
 * regression. Each data point is paired with a double value that indicates its
 * true regression value. An example of a regression problem would be mapping
 * the inputs of a function to its outputs, and attempting to learn the function
 * from the samples.
 * 
 * @author Edward Raff
 */
public class RegressionDataSet extends DataSet<RegressionDataSet> {

    /**
     * Regression target values; the value at index {@code i} is the target paired
     * with the {@code i}'th data point of this data set.
     */
    protected DoubleArrayList targets;

    /**
     * Constructs an empty regression data set with the given attribute layout.
     * The target list starts out empty.
     *
     * @param numerical  how many numerical attributes each data point has, not
     *                   counting the regression target
     * @param categories one entry per categorical attribute, describing that
     *                   attribute
     */
    public RegressionDataSet(int numerical, CategoricalData[] categories) {
        super(numerical, categories);
        this.targets = new DoubleArrayList();
    }

    /**
     * Creates a new dataset containing the given points paired with their target
     * values. Pairing is determined by the iteration order of each collection.
     *
     * @param datapoints the DataStore that will back this Data Set
     * @param targets    the target values to use
     */
    public RegressionDataSet(DataStore datapoints, List<Double> targets) {
        super(datapoints);
        this.targets = new DoubleArrayList(targets);
    }

    /**
     * Constructs a data set from a list of data points by pulling one numerical
     * attribute out of every point and using it as the regression target.
     * <p>
     * For each input point a new numerical vector, one element shorter than the
     * original, is built (sparse if the original reports itself as sparse, dense
     * otherwise). The categorical value array of each point is passed to the new
     * data point as returned by {@code getCategoricalValues()}. The categorical
     * schema is taken from the first data point in the list.
     *
     * @param data       the data points to build from; the first element is read
     *                   to determine the attribute layout
     * @param predicting index, among the numerical attributes only, of the
     *                   attribute to use as the regression target
     */
    public RegressionDataSet(List<DataPoint> data, int predicting) {
        super(data.get(0).numNumericalValues() - 1, data.get(0).getCategoricalData());
        // The first data point defines the categorical schema for the whole set
        final DataPoint first = data.get(0);
        categories = new CategoricalData[first.numCategoricalValues()];
        System.arraycopy(first.getCategoricalData(), 0, categories, 0, categories.length);
        targets = new DoubleArrayList(data.size());

        for (DataPoint source : data) {
            final Vec full = source.getNumericalValues();
            final int reducedLength = full.length() - 1;
            final Vec reduced = full.isSparse() ? new SparseVector(reducedLength, full.nnz()) : new DenseVector(reducedLength);

            // Stays zero when the iteration never yields the target index
            // (e.g. a sparse vector with no stored entry there)
            double target = 0;
            for (IndexValue entry : full) {
                final int index = entry.getIndex();
                final double value = entry.getValue();
                if (index == predicting) {
                    target = value;
                } else {
                    // Entries located after the target shift down by one slot
                    reduced.set(index < predicting ? index : index - 1, value);
                }
            }

            datapoints.addDataPoint(new DataPoint(reduced, source.getCategoricalValues(), categories));
            targets.add(target);
        }
    }

    /**
     * Constructs a data set from a list of data point / target pairs. The
     * attribute layout is read from the first pair; the categorical descriptors
     * are duplicated with {@code CategoricalData.copyOf}. Every pair's data point
     * is added to a new {@link RowMajorStore} and its target value to a new
     * target list, so later changes to {@code list} itself are not seen by this
     * data set.
     *
     * @param list the pairs to load; the first element is read to determine the
     *             attribute layout
     */
    public RegressionDataSet(List<DataPointPair<Double>> list) {
        super(list.get(0).getDataPoint().numNumericalValues(), CategoricalData.copyOf(list.get(0).getDataPoint().getCategoricalData()));
        this.datapoints = new RowMajorStore(numNumerVals, categories);
        this.targets = new DoubleArrayList();
        for (DataPointPair<Double> pair : list) {
            this.datapoints.addDataPoint(pair.getDataPoint());
            this.targets.add(pair.getPair());
        }
    }

    /**
     * Creates a placeholder data set backed by a {@link RowMajorStore} with one
     * numerical attribute and no categorical attributes. The target list is
     * initialized to an empty list so that an instance created this way can be
     * used without a {@link NullPointerException} on the first target access.
     */
    private RegressionDataSet() {
        super(new RowMajorStore(1, new CategoricalData[0]));
        this.targets = new DoubleArrayList();
    }

    /**
     * Merges every data set in the list, except the one at the given index, into
     * a single new data set. The attribute layout of the result is taken from
     * the excluded data set. Points are added through
     * {@link #addDataPoint(DataPoint, double)}.
     *
     * @param list      the data sets to merge
     * @param exception the index of the data set to leave out
     * @return a new data set holding the points and targets of all other sets
     */
    public static RegressionDataSet comineAllBut(List<RegressionDataSet> list, int exception) {
        final int numericalCount = list.get(exception).getNumNumericalVars();
        final CategoricalData[] schema = list.get(exception).getCategories();
        final RegressionDataSet combined = new RegressionDataSet(numericalCount, schema);

        for (int setIdx = 0; setIdx < list.size(); setIdx++) {
            if (setIdx != exception) {
                final RegressionDataSet part = list.get(setIdx);
                for (int row = 0; row < part.size(); row++) {
                    combined.addDataPoint(part.getDataPoint(row), part.getTargetValue(row));
                }
            }
        }

        return combined;
    }

    /**
     * Shared zero-length array used as the categorical values when a data point
     * is added with numerical values only.
     */
    private static final int[] emptyInt = new int[0];

    /**
     * Adds a data point that has no categorical values. This forwards to
     * {@link #addDataPoint(Vec, int[], double)} with an empty categorical array.
     * The vector is used directly rather than copied, so changing it afterwards
     * changes the stored data point.
     *
     * @param numerical the numerical values for the data point
     * @param val       the target value
     * @throws RuntimeException if the three-argument overload rejects the values
     *                          as inconsistent with this data set
     */
    public void addDataPoint(Vec numerical, double val) {
        this.addDataPoint(numerical, emptyInt, val);
    }

    /**
     * Creates a new data point from the given values and adds it to the data
     * set. The arguments are used directly rather than copied, so modifying them
     * afterwards will affect the data set.
     * <p>
     * Negative categorical values are accepted without further checking, so that
     * missing values can be represented.
     *
     * @param numerical  the numerical values for the data point
     * @param categories the categorical values for the data point
     * @param val        the target value to predict
     * @throws IllegalArgumentException if the number of numerical or categorical
     *                                  values differs from what this data set
     *                                  stores, or a non-negative categorical
     *                                  value is not valid for its attribute
     */
    public void addDataPoint(Vec numerical, int[] categories, double val) {
        if (numerical.length() != numNumerVals) {
            throw new IllegalArgumentException("Data point does not contain enough numerical data points");
        }
        // The parameter shadows the field: compare against the data set's own
        // schema, otherwise the check can never fail
        if (categories.length != this.categories.length) {
            throw new IllegalArgumentException("Data point does not contain enough categorical data points");
        }

        for (int i = 0; i < categories.length; i++) {
            final int value = categories[i];
            // Only non-negative values are validated; negatives mark missing values
            if (value >= 0 && !this.categories[i].isValidCategory(value)) {
                throw new IllegalArgumentException("Categoriy value given is invalid");
            }
        }

        DataPoint dp = new DataPoint(numerical, categories, this.categories);
        addDataPoint(dp, val);
    }

    /**
     * Adds a data point with its target value, using a weight of {@code 1.0}.
     * Forwards to {@link #addDataPoint(DataPoint, double, double)}.
     *
     * @param dp  the data point to add
     * @param val the target value for this data point
     */
    public void addDataPoint(DataPoint dp, double val) {
        final double defaultWeight = 1.0;
        addDataPoint(dp, val, defaultWeight);
    }

    /**
     * Adds a data point with its target value and weight. The point is appended
     * to the backing store, the target to the target list, and the weight is
     * then set for the newly added last index.
     *
     * @param dp     the data point to add
     * @param val    the target value for this data point
     * @param weight the weight for this data point
     * @throws RuntimeException    if the point's number of numerical or
     *                             categorical values differs from this data set's
     * @throws ArithmeticException if {@code val} is infinite or NaN
     */
    public void addDataPoint(DataPoint dp, double val, double weight) {
        final boolean shapeMatches = dp.numNumericalValues() == getNumNumericalVars() && dp.numCategoricalValues() == getNumCategoricalVars();
        if (!shapeMatches) {
            throw new RuntimeException("The added data point does not match the number of values and categories for the data set");
        }
        if (Double.isNaN(val) || Double.isInfinite(val)) {
            throw new ArithmeticException("Unregressiable value " + val + " given for regression");
        }

        datapoints.addDataPoint(dp);
        targets.add(val);
        final int lastIndex = size() - 1;
        setWeight(lastIndex, weight);
    }

    /**
     * Adds the data point and target value held by the given pair. Forwards to
     * {@link #addDataPoint(DataPoint, double)}.
     *
     * @param pair the data point together with its target value
     */
    public void addDataPointPair(DataPointPair<Double> pair) {
        final DataPoint point = pair.getDataPoint();
        addDataPoint(point, pair.getPair());
    }

    /**
     * Returns the i'th data point of the data set together with its target
     * regression value, wrapped in a newly created pair. The data point inside
     * the pair is whatever {@code getDataPoint(i)} returns, not an explicit copy.
     *
     * @param i the index of the data point to obtain
     * @return the i'th DataPointPair
     */
    public DataPointPair<Double> getDataPointPair(int i) {
        final DataPointPair<Double> pair = new DataPointPair<>(getDataPoint(i), targets.getDouble(i));
        return pair;
    }

    /**
     * Builds a new list of data point / target pairs in which every data point
     * is a clone of the one held by this data set. Changes to the returned list
     * do not change this data set; the points in it are the result of
     * {@code clone()}.
     *
     * @return a list of cloned data points paired with their target values
     */
    public List<DataPointPair<Double>> getAsDPPList() {
        final int count = size();
        final List<DataPointPair<Double>> copies = new ArrayList<>(count);
        for (int idx = 0; idx < count; idx++) {
            final DataPoint cloned = getDataPoint(idx).clone();
            copies.add(new DataPointPair<>(cloned, targets.getDouble(idx)));
        }
        return copies;
    }

    /**
     * Builds a new list holding the result of {@link #getDataPointPair(int)} for
     * every index of this data set. The list itself is new, so adding to or
     * removing from it does not change the data set; the data points are not
     * cloned. Use {@link #getAsDPPList()} to obtain cloned points instead.
     *
     * @return a list of the data points in this set paired with their targets
     */
    public List<DataPointPair<Double>> getDPPList() {
        final int count = size();
        final List<DataPointPair<Double>> pairs = new ArrayList<>(count);
        for (int idx = 0; idx < count; idx++) {
            pairs.add(getDataPointPair(idx));
        }
        return pairs;
    }

    /**
     * Replaces the target regression value stored for a data point.
     *
     * @param i   the index in the data set
     * @param val the new target value
     * @throws ArithmeticException if {@code val} is infinite or NaN
     */
    public void setTargetValue(int i, double val) {
        final boolean unusable = Double.isNaN(val) || Double.isInfinite(val);
        if (unusable) {
            throw new ArithmeticException("Can not predict a " + val + " value");
        }
        targets.set(i, val);
    }

    /**
     * Builds a new data set, with the same attribute layout as this one, that
     * holds the data points and target values found at the given indices. Points
     * are added in the iteration order of the list through
     * {@link #addDataPoint(DataPoint, double)}.
     *
     * @param indicies the indices of the data points to include
     * @return a new data set containing the selected points
     */
    @Override
    protected RegressionDataSet getSubset(List<Integer> indicies) {
        final RegressionDataSet subset = new RegressionDataSet(numNumerVals, categories);
        for (int index : indicies) {
            subset.addDataPoint(getDataPoint(index), getTargetValue(index));
        }
        return subset;
    }

    /**
     * Copies the target regression values into a new dense vector, one element
     * per data point. Changing the returned vector does not change the data set.
     *
     * @return a new vector containing the target value of each data point
     */
    public Vec getTargetValues() {
        final int count = size();
        final DenseVector copy = new DenseVector(count);
        for (int idx = 0; idx < count; idx++) {
            copy.set(idx, targets.getDouble(idx));
        }
        return copy;
    }

    /**
     * Looks up the target regression value stored for the {@code i}'th data
     * point of the data set.
     *
     * @param i the index of the data point whose target value is wanted
     * @return the target regression value
     */
    public double getTargetValue(int i) {
        final double value = targets.getDouble(i);
        return value;
    }

    /**
     * Creates a new data set from the given list of data point / target pairs by
     * calling the {@code List<DataPointPair<Double>>} constructor. That
     * constructor loads the pairs into a new store and a new target list, so the
     * given list does not back the returned data set.
     *
     * @param list the pairs to create the new data set from
     * @return a new data set
     */
    public static RegressionDataSet usingDPPList(List<DataPointPair<Double>> list) {
        final RegressionDataSet created = new RegressionDataSet(list);
        return created;
    }

    /**
     * Creates a new data set with the same attribute layout and adds to it, in
     * order, the pair returned by {@link #getDataPointPair(int)} for every index.
     * The data points are not cloned.
     *
     * @return a new data set referring to the same data points and target values
     */
    @Override
    public RegressionDataSet shallowClone() {
        final RegressionDataSet copy = new RegressionDataSet(numNumerVals, categories);
        final int count = size();
        for (int idx = 0; idx < count; idx++) {
            copy.addDataPointPair(getDataPointPair(idx));
        }
        return copy;
    }

    /**
     * Creates a new data set with the same attribute layout as this one and no
     * data points.
     *
     * @return a new, empty data set
     */
    @Override
    public RegressionDataSet emptyClone() {
        final RegressionDataSet empty = new RegressionDataSet(numNumerVals, categories);
        return empty;
    }

    /**
     * Delegates to the superclass implementation and narrows the result to
     * {@code RegressionDataSet}.
     *
     * @return the superclass's twice-shallow clone, cast to this type
     */
    @Override
    public RegressionDataSet getTwiceShallowClone() {
        final RegressionDataSet twiceShallow = (RegressionDataSet) super.getTwiceShallowClone();
        return twiceShallow;
    }
}
